{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 升级 TF1 为 TF2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-06 16:23:07.117814: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2024-01-06 16:23:07.160960: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-01-06 16:23:07.161006: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-01-06 16:23:07.161048: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-01-06 16:23:07.169897: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2024-01-06 16:23:07.170731: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-01-06 16:23:08.330236: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "# Alias the TF1-style API as `tf1`; fall back to the top-level module when\n",
    "# `tf.compat.v1` cannot be resolved.\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError):\n",
    "    tf1 = tf\n",
    "\n",
    "# Raise the threshold of TensorFlow's Python logger to ERROR.\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_dir = \".temp\"  # cache directory for the checkpoints and SavedModels written below\n",
    "# Create the cache directory up front so that the first checkpoint save on a\n",
    "# fresh checkout cannot fail because the parent directory is missing.\n",
    "os.makedirs(temp_dir, exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 检查点迁移\n",
    "\n",
    "定义辅助函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_checkpoint(save_path, print_value=True, tab_size=10):\n",
    "    \"\"\"Print one table row per variable stored in a checkpoint.\n",
    "\n",
    "    Args:\n",
    "        save_path: Checkpoint prefix handed to `tf.train.load_checkpoint`.\n",
    "        print_value: If true, append each stored tensor value as a last column.\n",
    "        tab_size: Padded width of the key column. The shape column is padded\n",
    "            to `max(tab_size, 20)` and the dtype and per-row value columns to\n",
    "            `max(tab_size, 10)`; longer entries are not truncated.\n",
    "    \"\"\"\n",
    "    shape_width = max(tab_size, 20)\n",
    "    narrow_width = max(tab_size, 10)\n",
    "    reader = tf.train.load_checkpoint(save_path)\n",
    "    shape_by_key = reader.get_variable_to_shape_map()\n",
    "    dtype_by_key = reader.get_variable_to_dtype_map()\n",
    "\n",
    "    print(f\"检查点: {save_path}\")\n",
    "    header_cells = (\n",
    "        \"key\".ljust(tab_size),\n",
    "        \"\\tshape\".ljust(shape_width),\n",
    "        \"\\tdtype\".ljust(narrow_width),\n",
    "        # The header pads its value cell with tab_size, unlike the data rows.\n",
    "        \"\\tvalue\".ljust(tab_size),\n",
    "    )\n",
    "    print(\"\".join(header_cells))\n",
    "    print(\"=\" * (tab_size * 7))\n",
    "\n",
    "    for key, shape in shape_by_key.items():\n",
    "        row_cells = [\n",
    "            f\"{key}\".ljust(tab_size),\n",
    "            f\"\\t{shape}\".ljust(shape_width),\n",
    "            f\"\\t{dtype_by_key[key].name}\".ljust(narrow_width),\n",
    "        ]\n",
    "        if print_value:\n",
    "            row_cells.append(f\"\\t{reader.get_tensor(key)}\".ljust(narrow_width))\n",
    "        print(\"\".join(row_cells))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "先看 TF1 的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检查点: .temp/tf1-ckpt\n",
      "key       \tshape              \tdtype    \tvalue    \n",
      "======================================================================\n",
      "scoped/c  \t[]                 \tuint8    \t3        \n",
      "b         \t[]                 \tuint8    \t2        \n",
      "a         \t[]                 \tfloat32  \t1.0      \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-06 16:23:09.741226: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "2024-01-06 16:23:09.762581: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:382] MLIR V1 optimization pass is not enabled\n"
     ]
    }
   ],
   "source": [
    "# Name-based (TF1) checkpoint: create three scalar variables in a fresh graph.\n",
    "with tf.Graph().as_default() as g:\n",
    "    a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    b = tf1.get_variable('b', shape=[], dtype=tf.uint8, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.uint8, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    with tf1.Session() as sess:\n",
    "        saver = tf1.train.Saver()\n",
    "        # Assign 1, 2 and 3 so each variable is recognisable in the printed table.\n",
    "        sess.run(a.assign(1))\n",
    "        sess.run(b.assign(2))\n",
    "        sess.run(c.assign(3))\n",
    "        saver.save(sess, f'{temp_dir}/tf1-ckpt')\n",
    "# The keys printed below are the variable names (a, b, scoped/c).\n",
    "print_checkpoint(f'{temp_dir}/tf1-ckpt')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF2 的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检查点: .temp/tf2-ckpt-1\n",
      "key                             \tshape                          \tdtype                          \tvalue                          \n",
      "================================================================================================================================================================================================================================\n",
      "variables/2/.ATTRIBUTES/VARIABLE_VALUE\t[]                             \tfloat32                        \t7.0                            \n",
      "variables/1/.ATTRIBUTES/VARIABLE_VALUE\t[]                             \tfloat32                        \t6.0                            \n",
      "variables/0/.ATTRIBUTES/VARIABLE_VALUE\t[]                             \tfloat32                        \t5.0                            \n",
      "save_counter/.ATTRIBUTES/VARIABLE_VALUE\t[]                             \tint64                          \t1                              \n",
      "_CHECKPOINTABLE_OBJECT_GRAPH    \t[]                             \tstring                         \tb\"\\n%\\n\\r\\x08\\x01\\x12\\tvariables\\n\\x10\\x08\\x02\\x12\\x0csave_counter*\\x02\\x08\\x01\\n\\x19\\n\\x05\\x08\\x03\\x12\\x010\\n\\x05\\x08\\x04\\x12\\x011\\n\\x05\\x08\\x05\\x12\\x012*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nA\\x12;\\n\\x0eVARIABLE_VALUE\\x12\\x01a\\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nA\\x12;\\n\\x0eVARIABLE_VALUE\\x12\\x01b\\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nI\\x12C\\n\\x0eVARIABLE_VALUE\\x12\\tscoped2/c\\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\"\n"
     ]
    }
   ],
   "source": [
    "a = tf.Variable(5.0, name='a')\n",
    "b = tf.Variable(6.0, name='b')\n",
    "with tf.name_scope('scoped2'):\n",
    "    c = tf.Variable(7.0, name='c')\n",
    "# Object-based (TF2) checkpoint: the three variables are attached to the\n",
    "# checkpoint as a list under the keyword `variables`.\n",
    "ckpt = tf.train.Checkpoint(variables=[a, b, c])\n",
    "# `save` returns the path it wrote, which is the one inspected below.\n",
    "save_path_v2 = ckpt.save(f'{temp_dir}/tf2-ckpt')\n",
    "print_checkpoint(save_path_v2, tab_size=32)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "基于名称的检查点中的键是变量的名称。基于对象的检查点中的键指向从根对象到变量的路径。\n",
    "```\n",
    "\n",
    "查看 `tf2-ckpt` 中的键，可以看出它们全部指向每个变量的对象路径。\n",
    "\n",
    "仔细研究下面的打印信息："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root type = Checkpoint\n",
      "root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>])\n",
      "root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>\n"
     ]
    }
   ],
   "source": [
    "# Fresh unnamed variables; they replace the a, b, c defined in earlier cells.\n",
    "a = tf.Variable(0.)\n",
    "b = tf.Variable(0.)\n",
    "c = tf.Variable(0.)\n",
    "# `root` and `ckpt` are two names bound to the same Checkpoint object.\n",
    "root = ckpt = tf.train.Checkpoint(variables=[a, b, c])\n",
    "print(\"root type =\", type(root).__name__)\n",
    "# The keyword passed to the constructor is readable as an attribute of the root.\n",
    "print(\"root.variables =\", root.variables)\n",
    "print(\"root.variables[0] =\", root.variables[0])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "尝试使用下面的代码段，看看检查点键如何随对象结构变化："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检查点: .temp/root-tf2-ckpt-1\n",
      "key                      \tshape                   \tdtype                   \tvalue                   \n",
      "===============================================================================================================================================================================\n",
      "v/b/.ATTRIBUTES/VARIABLE_VALUE\t[]                      \tfloat32                 \t0.0                     \n",
      "v/a/.ATTRIBUTES/VARIABLE_VALUE\t[]                      \tfloat32                 \t0.0                     \n",
      "module/d/.ATTRIBUTES/VARIABLE_VALUE\t[]                      \tfloat32                 \t0.0                     \n",
      "save_counter/.ATTRIBUTES/VARIABLE_VALUE\t[]                      \tint64                   \t1                       \n",
      "c/.ATTRIBUTES/VARIABLE_VALUE\t[]                      \tfloat32                 \t0.0                     \n",
      "_CHECKPOINTABLE_OBJECT_GRAPH\t[]                      \tstring                  \tb\"\\n0\\n\\x05\\x08\\x01\\x12\\x01c\\n\\n\\x08\\x02\\x12\\x06module\\n\\x05\\x08\\x03\\x12\\x01v\\n\\x10\\x08\\x04\\x12\\x0csave_counter*\\x02\\x08\\x01\\n>\\x128\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n\\x0b\\n\\x05\\x08\\x05\\x12\\x01d*\\x02\\x08\\x01\\n\\x12\\n\\x05\\x08\\x06\\x12\\x01a\\n\\x05\\x08\\x07\\x12\\x01b*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nE\\x12?\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\"\n"
     ]
    }
   ],
   "source": [
    "# A module with a single variable attribute `d`.\n",
    "module = tf.Module()\n",
    "module.d = tf.Variable(0.)\n",
    "# `a`, `b` and `c` are the variables defined in an earlier cell. Each keyword\n",
    "# (`v`, `c`, `module`) appears as the leading path component of the keys printed below.\n",
    "test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, \n",
    "                                c=c,\n",
    "                                module=module)\n",
    "test_ckpt_path = test_ckpt.save(f'{temp_dir}/root-tf2-ckpt')\n",
    "print_checkpoint(test_ckpt_path, tab_size=25)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面是对不同模型使用相同检查点的示例。\n",
    "\n",
    "1. 使用 `tf1.train.Saver` 保存 TF1 检查点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检查点: .temp/tf1-ckpt\n",
      "key       \tshape              \tdtype    \tvalue    \n",
      "======================================================================\n",
      "scoped/c  \t[]                 \tuint8    \t3        \n",
      "b         \t[]                 \tuint8    \t2        \n",
      "a         \t[]                 \tfloat32  \t1.0      \n"
     ]
    }
   ],
   "source": [
    "# Same TF1 graph-mode save as in the earlier example; it writes the same\n",
    "# `tf1-ckpt` path that the following cells load from.\n",
    "with tf.Graph().as_default() as g:\n",
    "    a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    b = tf1.get_variable('b', shape=[], dtype=tf.uint8, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.uint8, \n",
    "                         initializer=tf1.zeros_initializer())\n",
    "    with tf1.Session() as sess:\n",
    "        saver = tf1.train.Saver()\n",
    "        # Assign 1, 2 and 3 so the restored values can be checked by eye later.\n",
    "        sess.run(a.assign(1))\n",
    "        sess.run(b.assign(2))\n",
    "        sess.run(c.assign(3))\n",
    "        saver.save(sess, f'{temp_dir}/tf1-ckpt')\n",
    "\n",
    "print_checkpoint(f'{temp_dir}/tf1-ckpt')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 使用 `tf.compat.v1.train.Saver` 在 Eager 模式下加载检查点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载后的值 [a, b, c]:  [1.0, 2, 3]\n",
      "检查点: .temp/tf1-ckpt-saved-in-eager\n",
      "key       \tshape              \tdtype    \tvalue    \n",
      "======================================================================\n",
      "scoped/c  \t[]                 \tuint8    \t3        \n",
      "b         \t[]                 \tuint8    \t2        \n",
      "a         \t[]                 \tfloat32  \t1.0      \n"
     ]
    }
   ],
   "source": [
    "# Eager variables whose names (a, b, scoped/c) match those saved in the TF1 checkpoint.\n",
    "a = tf.Variable(0, name=\"a\", dtype=tf.float32)\n",
    "b = tf.Variable(0, name=\"b\", dtype=tf.uint8)\n",
    "with tf.name_scope('scoped'):\n",
    "    c = tf.Variable(0, name='c', dtype=tf.uint8)\n",
    "\n",
    "# Collections were removed in TF2, so the variable list must be passed to the Saver explicitly:\n",
    "saver = tf1.train.Saver(var_list=[a, b, c])\n",
    "saver.restore(sess=None, save_path=f'{temp_dir}/tf1-ckpt')\n",
    "print(f\"加载后的值 [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]\")\n",
    "# Saving also works eagerly (sess must be None).\n",
    "path = saver.save(sess=None, save_path=f'{temp_dir}/tf1-ckpt-saved-in-eager')\n",
    "print_checkpoint(path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. 使用 TF2 API `tf.train.Checkpoint` 加载检查点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "变量名称: \n",
      "\ta.name = a:0\n",
      "\tb.name = b:0\n",
      "\tc.name = scoped/c:0\n",
      "\tc_2.name = scoped/c:0\n",
      "加载后的值 [a, b, c, c_2]:  [1.0, 2, 3, 3]\n"
     ]
    }
   ],
   "source": [
    "# Eager variables named after the entries of the TF1 checkpoint.\n",
    "a = tf.Variable(0, name=\"a\", dtype=tf.float32)\n",
    "b = tf.Variable(0, name=\"b\", dtype=tf.uint8)\n",
    "with tf.name_scope('scoped'):\n",
    "    c = tf.Variable(0, name='c', dtype=tf.uint8)\n",
    "\n",
    "# Without the name_scope, name=\"scoped/c\" works too:\n",
    "c_2 = tf.Variable(0, name='scoped/c', dtype=tf.uint8)\n",
    "\n",
    "print(\"变量名称: \")\n",
    "print(f\"\\ta.name = {a.name}\")\n",
    "print(f\"\\tb.name = {b.name}\")\n",
    "print(f\"\\tc.name = {c.name}\")\n",
    "print(f\"\\tc_2.name = {c_2.name}\")\n",
    "\n",
    "# Restore the values with tf.train.Checkpoint\n",
    "ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])\n",
    "# NOTE(review): the value returned by `restore` is discarded, so nothing here\n",
    "# verifies that every variable was matched; the print below is the only check.\n",
    "ckpt.restore(f'{temp_dir}/tf1-ckpt')\n",
    "print(f\"加载后的值 [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SavedModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "def remove_dir(path):\n",
    "    \"\"\"Delete the directory tree at `path` on a best-effort basis.\n",
    "\n",
    "    A missing directory, or one that cannot be removed, is ignored so the\n",
    "    export cells can be re-run. Only filesystem errors (`OSError`) are\n",
    "    suppressed; `KeyboardInterrupt` and programming errors still propagate.\n",
    "\n",
    "    Args:\n",
    "        path: Directory to remove.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        shutil.rmtree(path)\n",
    "    except OSError:\n",
    "        pass"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义简单运算："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_two(x):\n",
    "    \"\"\"Return `x + 2`; the toy computation exported by the SavedModel examples.\"\"\"\n",
    "    shifted = x + 2\n",
    "    return shifted"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow 1：保存和导出 SavedModel\n",
    "\n",
    "在 TensorFlow 1 中，使用 `tf.compat.v1.saved_model.Builder`、`tf.compat.v1.saved_model.simple_save` 和 `tf.estimator.Estimator.export_saved_model` API 来构建、保存及导出 TensorFlow 计算图和会话。\n",
    "\n",
    "1. 使用 SavedModelBuilder 将计算图保存为 SavedModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "结果为 5.0\n"
     ]
    }
   ],
   "source": [
    "model_dir = f\"{temp_dir}/saved-model-builder\"\n",
    "# Clear any previous export at this path before building a new one.\n",
    "remove_dir(model_dir)\n",
    "\n",
    "with tf.Graph().as_default() as g:\n",
    "    x = tf1.placeholder(tf.float32, shape=[])\n",
    "    y = add_two(x)\n",
    "    with tf1.Session() as sess:\n",
    "        print(f\"结果为 {sess.run(y, {x: 3.})}\")\n",
    "\n",
    "        # Save with SavedModelBuilder\n",
    "        builder = tf1.saved_model.Builder(model_dir)\n",
    "        # Signature that exposes placeholder `x` as `input` and `y` as `output`.\n",
    "        sig_def = tf1.saved_model.predict_signature_def(\n",
    "            inputs={'input': x},\n",
    "            outputs={'output': y}\n",
    "        )\n",
    "        # Register the signature under the default serving key with the SERVING tag.\n",
    "        builder.add_meta_graph_and_variables(\n",
    "            sess=sess,\n",
    "            tags=[tf1.saved_model.tag_constants.SERVING],\n",
    "            signature_def_map={\n",
    "                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def\n",
    "        })\n",
    "        builder.save()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 使用 `simple_save` 构建用于部署（serving）的 SavedModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "结果为 5.0\n"
     ]
    }
   ],
   "source": [
    "# Clear any previous export at this path before saving again.\n",
    "remove_dir(f\"{temp_dir}/simple-save\")\n",
    "\n",
    "with tf.Graph().as_default() as g:\n",
    "    x = tf1.placeholder(tf.float32, shape=[])\n",
    "    y = add_two(x)\n",
    "    with tf1.Session() as sess:\n",
    "        print(f\"结果为 {sess.run(y, {x: 3.})}\")\n",
    "        # The `input` / `output` keys are the names the loading cells look up later.\n",
    "        tf1.saved_model.simple_save(\n",
    "            sess, f\"{temp_dir}/simple-save\",\n",
    "            inputs={'input': x},\n",
    "            outputs={'output': y}\n",
    "        )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow 2：保存和导出 SavedModel\n",
    "\n",
    "#### 保存并导出使用 `tf.Module` 定义的 SavedModel\n",
    "\n",
    "要在 TensorFlow 2 中导出模型，必须定义 `tf.Module` 或 `tf.keras.Model` 来保存模型的所有变量和函数。随后，可以调用 `tf.saved_model.save` 来创建 SavedModel。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(tf.Module):\n",
    "    \"\"\"Minimal `tf.Module` whose call applies `add_two` to its input.\"\"\"\n",
    "    @tf.function\n",
    "    def __call__(self, x):\n",
    "        return add_two(x)\n",
    "    \n",
    "model = MyModel()\n",
    "\n",
    "@tf.function\n",
    "def serving_default(x):\n",
    "    \"\"\"Serving entry point: returns the model result under the `output` key.\"\"\"\n",
    "    return {\"output\": model(x)}\n",
    "\n",
    "# Trace `serving_default` for a scalar float32 input; the resulting concrete\n",
    "# function is what gets exported as the default serving signature.\n",
    "signature_function = serving_default.get_concrete_function(\n",
    "    tf.TensorSpec(shape=[], dtype=tf.float32)\n",
    ")\n",
    "tf.saved_model.save(\n",
    "    model, f\"{temp_dir}/tf2-save\",\n",
    "    signatures={\n",
    "        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 保存并导出使用 Keras 定义的 SavedModel\n",
    "\n",
    "用于保存和导出的 Keras API（`Model.save` 或 `tf.keras.models.save_model`）可以从 `tf.keras.Model` 导出 SavedModel。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functional Keras model that applies `add_two` to its input. This rebinds\n",
    "# `model`, which held the `MyModel` instance in the previous cell.\n",
    "inp = tf.keras.Input(3)\n",
    "out = add_two(inp)\n",
    "model = tf.keras.Model(inputs=inp, outputs=out)\n",
    "\n",
    "# NOTE(review): the parameter is called `input` (shadowing the builtin), presumably on\n",
    "# purpose: `load_tf1` looks the signature input up as `sig_def.inputs['input']`,\n",
    "# so renaming it would likely break that lookup -- confirm before changing.\n",
    "@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])\n",
    "def serving_default(input):\n",
    "    return {'output': model(input)}\n",
    "\n",
    "# Export in SavedModel format with `serving_default` as the default serving signature.\n",
    "model.save(\n",
    "    f\"{temp_dir}/keras-model\", save_format='tf', \n",
    "    signatures={tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default}\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载 SavedModel\n",
    "\n",
    "### TensorFlow 1：使用 `tf.compat.v1.saved_model.load` 加载 SavedModel\n",
    "\n",
    "在 TensorFlow 1 中，可以使用 `tf.compat.v1.saved_model.load` 将 SavedModel 直接导入当前计算图和会话。可以在张量输入和输出名称上调用 `Session.run`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载 .temp/saved-model-builder\n",
      "5.0 => 7.0\n",
      "加载 .temp/simple-save\n",
      "5.0 => 7.0\n",
      "加载 .temp/keras-model\n",
      "5.0 => 7.0\n"
     ]
    }
   ],
   "source": [
    "def load_tf1(path, x):\n",
    "    \"\"\"Load a SavedModel into a fresh TF1 graph/session and print its output for `x`.\n",
    "\n",
    "    Args:\n",
    "        path: SavedModel directory, loaded with the `serve` tag.\n",
    "        x: Value fed to the tensor registered as the signature's `input`.\n",
    "    \"\"\"\n",
    "    print(f\"加载 {path}\")\n",
    "    with tf.Graph().as_default(), tf1.Session() as sess:\n",
    "        meta_graph = tf1.saved_model.load(sess, [\"serve\"], path)\n",
    "        serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n",
    "        signature = meta_graph.signature_def[serving_key]\n",
    "        # Tensor names recorded in the signature are used directly as feed/fetch keys.\n",
    "        feed_name = signature.inputs['input'].name\n",
    "        fetch_name = signature.outputs['output'].name\n",
    "        result = sess.run(fetch_name, feed_dict={feed_name: x})\n",
    "        print(x, '=>', result)\n",
    "# Run the TF1 loader on each export whose signature exposes an `input` key.\n",
    "for export_name in ('saved-model-builder', 'simple-save', 'keras-model'):\n",
    "    load_tf1(f'{temp_dir}/{export_name}', 5.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow 2：加载使用 tf.saved_model 保存的模型\n",
    "\n",
    "在 TensorFlow 2 中，对象会加载到存储变量和函数的 Python 对象中。这与从 TensorFlow 1 保存的模型兼容。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tf2(path, x):\n",
    "    \"\"\"Load a SavedModel with the TF2 loader and print its serving output for `x`.\n",
    "\n",
    "    Args:\n",
    "        path: SavedModel directory passed to `tf.saved_model.load`.\n",
    "        x: Python value wrapped in `tf.constant` and passed to the default\n",
    "            serving signature.\n",
    "    \"\"\"\n",
    "    print(f\"加载 {path}\")\n",
    "    # Keep the loaded object referenced for as long as its signature is in use.\n",
    "    loaded_model = tf.saved_model.load(path)\n",
    "    serving_fn = loaded_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
    "    outputs = serving_fn(tf.constant(x))\n",
    "    print(x, '=>', outputs['output'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载 .temp/saved-model-builder\n",
      "5.0 => tf.Tensor(7.0, shape=(), dtype=float32)\n",
      "加载 .temp/simple-save\n",
      "5.0 => tf.Tensor(7.0, shape=(), dtype=float32)\n",
      "加载 .temp/tf2-save\n",
      "5.0 => tf.Tensor(7.0, shape=(), dtype=float32)\n",
      "加载 .temp/keras-model\n",
      "5.0 => tf.Tensor(7.0, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Run the TF2 loader on every export written by the cells above.\n",
    "for export_name in ('saved-model-builder', 'simple-save', 'tf2-save', 'keras-model'):\n",
    "    load_tf2(f'{temp_dir}/{export_name}', 5.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
